【LLM Review】Decrypt Ptuning - 2023W32

Reference paper: P-Tuning v2: Prompt Tuning Can Be Comparable to Fine-tuning Universally Across Scales and Tasks

1. Summary

P-tuning sits between prompt engineering and full fine-tuning (Full FT): the idea is to spend a small amount of training cost to improve on handcrafted prompts and approach Full FT quality. How close it gets depends on model size and task, but overall the approach has real practical value.

2. Differences between p-tuning, prompt tuning, and prefix tuning

P-tuning v1 is essentially prompt tuning: the trainable virtual prompt tokens can be inserted at flexible (not fixed) positions, the prompt is usually only a small number of tokens, and it is added only at the embedding layer. v2 is closer to prefix tuning: trainable prefixes are injected at every layer (deep prompt tuning), the reparameterization is dropped for some tasks, and there are a few other tricks; see Table 1 of the paper (image below) and the minimal sketch after the image for a code-level contrast.

(Pasted image 20230803222836 — Table 1 from the paper.)
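To make the difference concrete, here is a minimal, hypothetical PyTorch sketch (not the paper's code; the names `prompt_embeds`, `deep_prefix`, `v1_inputs`, `v2_past_key_values` and the RoBERTa-base-like sizes are all invented for illustration). v1-style prompt tuning only learns extra input embeddings, while v2-style deep prompts learn per-layer key/value prefixes that are fed to attention via `past_key_values`.

import torch
import torch.nn as nn

hidden, n_layers, n_heads, prompt_len = 768, 12, 12, 20
head_dim = hidden // n_heads

# P-tuning v1 / prompt tuning: trainable vectors at the embedding layer only
prompt_embeds = nn.Parameter(torch.randn(prompt_len, hidden))

def v1_inputs(inputs_embeds):                          # inputs_embeds: (batch, seq, hidden)
    batch = inputs_embeds.size(0)
    prefix = prompt_embeds.unsqueeze(0).expand(batch, -1, -1)
    return torch.cat([prefix, inputs_embeds], dim=1)   # prepended once, at the input

# P-tuning v2 / prefix tuning: trainable (key, value) prefixes for every layer
deep_prefix = nn.Parameter(torch.randn(n_layers, 2, n_heads, prompt_len, head_dim))

def v2_past_key_values(batch):
    # one (key, value) pair per layer, each of shape (batch, n_heads, prompt_len, head_dim),
    # i.e. the usual Hugging Face past_key_values layout
    return [
        (layer_kv[0].unsqueeze(0).expand(batch, -1, -1, -1),
         layer_kv[1].unsqueeze(0).expand(batch, -1, -1, -1))
        for layer_kv in deep_prefix
    ]

Because the v2 prefixes touch every layer, the number of tunable parameters grows with depth, but they influence the attention of each layer directly instead of only nudging the input.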

3. P-tuning v2 implementation code

import torch
from transformers import RobertaPreTrainedModel

class RobertaPrefixForTokenClassification(RobertaPreTrainedModel):
    # __init__ (not shown here) sets up self.roberta, the classification head,
    # self.prefix_tokens, self.prefix_encoder, self.pre_seq_len, self.n_layer,
    # self.n_head and self.n_embd.

    # deep prompt tuning: build per-layer key/value prefixes for the whole batch
    def get_prompt(self, batch_size):
        # prefix token ids 0..pre_seq_len-1, repeated for every example in the batch
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
        # (batch, pre_seq_len, 2 * n_layer * hidden)
        past_key_values = self.prefix_encoder(prefix_tokens)
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.n_layer * 2,
            self.n_head,
            self.n_embd
        )
        past_key_values = self.dropout(past_key_values)
        # -> tuple of n_layer tensors, each (2, batch, n_head, pre_seq_len, n_embd),
        #    i.e. one (key, value) prefix pair per transformer layer
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
        return past_key_values
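For context, a hedged sketch of how `get_prompt` is typically consumed in the model's `forward` (paraphrased from the P-Tuning v2 reference implementation; argument handling and return values are simplified, so the actual repo code may differ):

    # sketch: how forward() injects the deep prompts into RoBERTa
    def forward(self, input_ids, attention_mask, **kwargs):
        batch_size = input_ids.shape[0]
        past_key_values = self.get_prompt(batch_size=batch_size)

        # the prefix positions must be attended to as well, so extend the attention mask
        prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,  # one (key, value) prefix per layer
            **kwargs,
        )
        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)  # token-level classification head
        return logits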

import torch

# reparameterization encoder (two-layer MLP when prefix_projection is enabled)
class PrefixEncoder(torch.nn.Module):
    r'''
    The torch.nn model to encode the prefix

    Input shape: (batch-size, prefix-length)

    Output shape: (batch-size, prefix-length, 2*layers*hidden)
    '''
    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.prefix_hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.prefix_hidden_size, config.num_hidden_layers * 2 * config.hidden_size)
            )
        else:
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_hidden_layers * 2 * config.hidden_size)

    def forward(self, prefix: torch.Tensor):
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)
            past_key_values = self.trans(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values
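A quick standalone sanity check of the output shape and the trainable-parameter budget (hypothetical snippet; `SimpleNamespace` just stands in for the real model config):

from types import SimpleNamespace
import torch

config = SimpleNamespace(
    prefix_projection=True,   # use the two-layer MLP reparameterization
    pre_seq_len=20,           # prefix length
    hidden_size=768,
    prefix_hidden_size=512,
    num_hidden_layers=12,
)

encoder = PrefixEncoder(config)
prefix = torch.arange(config.pre_seq_len).unsqueeze(0).expand(4, -1)  # (batch=4, prefix_len)
out = encoder(prefix)
print(out.shape)  # torch.Size([4, 20, 18432]) = (batch, prefix_len, 2 * 12 layers * 768 hidden)

# Without the projection (prefix_projection=False) the trainable budget is just
# pre_seq_len * 2 * num_hidden_layers * hidden_size = 20 * 2 * 12 * 768 ≈ 0.37M parameters,
# a tiny fraction of the frozen RoBERTa backbone.

This is where the "small training cost" claim in the summary comes from: only the prefix encoder and the task head are updated, while the backbone stays frozen.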